
add _native_npu_attention support mask shape like [B,1,1,S]#13490

Merged
yiyixuxu merged 6 commits into huggingface:main from chang-zhijie:native_npu
Apr 18, 2026

Conversation

@chang-zhijie
Contributor

@chang-zhijie chang-zhijie commented Apr 16, 2026

This PR resolves the unsupported atten_mask shape error when running attention with NPU (Ascend) devices.

Problem:
The NPU's fusion attention operator (e.g., `npu_fusion_attention`) does not support automatic broadcasting for attention masks.
When a mask of shape `[batch, seq_len]` or `[batch, 1, 1, seq_len]` is passed, the operator fails with an error similar to:
`get unsupported atten_mask shape, the shape is [B, 1, 1, S]`; only shapes like `[B, N, S, S]`, `[B, 1, S, S]`, `[1, 1, S, S]`, or `[S, S]` are accepted.

Solution:
When running on NPU, explicitly expand the mask to [B, 1, S, S] to satisfy the operator’s shape constraints.
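A minimal sketch of the expansion described above (the helper name `expand_npu_attn_mask` is illustrative, not the actual diffusers code; the real change lives in `_native_npu_attention` in `attention_dispatch.py`):

```python
import torch


def expand_npu_attn_mask(attn_mask: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Hypothetical helper: normalize a mask to the [B, 1, S, S] layout
    # that npu_fusion_attention accepts (it does not broadcast masks itself).
    if attn_mask.ndim == 2:
        # [B, S] -> [B, 1, 1, S]
        attn_mask = attn_mask[:, None, None, :]
    if attn_mask.ndim == 4 and attn_mask.shape[1] == 1 and attn_mask.shape[2] == 1:
        # [B, 1, 1, S] -> [B, 1, S, S]: repeat the key mask for every query position.
        # expand() creates a broadcast view, so no extra memory is copied.
        attn_mask = attn_mask.expand(attn_mask.shape[0], 1, seq_len, attn_mask.shape[-1])
    return attn_mask


mask = torch.zeros(2, 16, dtype=torch.bool)
print(tuple(expand_npu_attn_mask(mask, 16).shape))  # -> (2, 1, 16, 16)
```

Using `expand` rather than `repeat` keeps this free on the NPU side, since the operator only needs a view with the accepted shape.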

Reference:
Ascend NPU fusion attention API:
https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md

Here is an example of running the ErnieImagePipeline with the `_native_npu` attention backend:

```python
import torch
import torch_npu
from diffusers import ErnieImagePipeline

pipe = ErnieImagePipeline.from_pretrained("/model_dir/ERNIE-Image", torch_dtype=torch.bfloat16)
pipe = pipe.to("npu")
pipe.transformer.set_attention_backend("_native_npu")
generator = torch.Generator(device="npu")

prompt = "A black and white Chinese rural dog"
images = pipe(
    prompt=prompt,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=generator,
    use_pe=True,
).images
images[0].save("ernie-image-output.png")
```

@github-actions github-actions Bot added models size/S PR with diff < 50 LOC labels Apr 16, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment

i left one feedback, thanks!

Comment thread src/diffusers/models/attention_dispatch.py Outdated
@yiyixuxu yiyixuxu added close-to-merge and removed size/S PR with diff < 50 LOC labels Apr 16, 2026
@yiyixuxu
Collaborator

@bot /style

@github-actions
Contributor

Style fix is beginning .... View the workflow run here.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions github-actions Bot added the size/S PR with diff < 50 LOC label Apr 17, 2026
Collaborator

@yiyixuxu yiyixuxu left a comment


thanks!

@yiyixuxu
Collaborator

@bot /style

@github-actions
Contributor

Style fix is beginning .... View the workflow run here.

@yiyixuxu
Collaborator

@chang-zhijie can you run make style? once CI passed I'll merge:)

@chang-zhijie
Contributor Author

@bot /style

@yiyixuxu yiyixuxu merged commit c8c8401 into huggingface:main Apr 18, 2026
13 of 14 checks passed
terarachang pushed a commit to terarachang/diffusers that referenced this pull request Apr 30, 2026
…ace#13490)

* add _native_npu_attention support mask shape like [B,1,1,S]

* add _native_npu_attention support mask shape like [B,1,1,S]

* fix style

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
